import os
import h5py
import numpy as np
import argparse
from IPython import embed
from pytorch3d.loss import chamfer_distance
import torch
from pytorch3d.ops import cubify, sample_points_from_meshes
from scipy.optimize import linear_sum_assignment

def _str2bool(value):
    """Parse a textual command-line flag value into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string -- including ``"False"`` -- so a flag declared that
    way can never be switched off from the command line. Parse the text
    explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        # Defaults are already bools; pass them through unchanged.
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(
    description='Compute Chamfer distance and (optionally) EMD between '
                'ground-truth and reconstructed point clouds.')

parser.add_argument('--path', type=str, required=True,
                    help='run sub-directory under the test_completion output '
                         'folder that contains recon_gt.pth and recon_samples.pth')
parser.add_argument('--category', type=str, default='noneed',
                    help='shape category (not used by this script)')
parser.add_argument('--emd', type=_str2bool, default=True,
                    help='also compute EMD via exact assignment (slow); '
                         'true/false')
parser.add_argument('--ori', type=_str2bool, default=False,
                    help='boolean flag (not used by this script); true/false')


args = parser.parse_args()

def _pointcloud_emd(q1, q2):
    """Return the mean matched-point distance between two point clouds.

    Both inputs are NumPy arrays of shape (N, 3) with the same N. The full
    N x N matrix of squared point-to-point distances is built, an optimal
    one-to-one matching is found with ``scipy.optimize.linear_sum_assignment``,
    and the Euclidean distances of the matched pairs are averaged.

    Memory: the intermediate difference array is N x N x 3.

    NOTE(review): the matching minimises *squared* distances while the
    returned value averages *unsquared* distances. This is kept as in the
    original; confirm this is the intended EMD variant.
    """
    # cost[i, j] = ||q1[i] - q2[j]||^2 via broadcasting.
    cost = np.square(q1[:, None, :] - q2[None, :, :]).sum(axis=-1)
    # For a square cost matrix the row indices are 0..N-1 in order, so
    # col_ind[i] is the point of q2 matched to q1[i].
    _, col_ind = linear_sum_assignment(cost)
    matched_diff = q1 - q2[col_ind]
    return np.mean(np.sqrt(np.square(matched_diff).sum(axis=1)))


# Output directory of the completion run selected by --path.
base_dir = os.path.join('/home/tiangel/PVD/output/test_completion/', args.path)
ori_shapes = torch.load(os.path.join(base_dir, 'recon_gt.pth'))
rec_shapes = torch.load(os.path.join(base_dir, 'recon_samples.pth'))
torch.set_printoptions(precision=7)

# View the saved tensors as (num_shapes, 2048, 3); assumes every cloud has
# exactly 2048 xyz points -- TODO confirm against the generator.
ori_shapes = ori_shapes.reshape(-1, 2048, 3)
rec_shapes = rec_shapes.reshape(-1, 2048, 3)
cd_dis = chamfer_distance(ori_shapes.cuda(), rec_shapes.cuda())[0]
print('cd_dis:', cd_dis)

if args.emd:
    emd_dis = []
    for i in range(ori_shapes.shape[0]):
        print('emd', i)
        # detach/cpu so this works whether the saved tensors were on CPU or GPU.
        q1 = ori_shapes[i].detach().cpu().numpy()
        q2 = rec_shapes[i].detach().cpu().numpy()
        emd_dis.append(_pointcloud_emd(q1, q2))
    # Report only when EMD was computed; otherwise emd_dis does not exist.
    print('emd_dis:', np.mean(emd_dis))

# Interactive IPython shell for inspecting the results.
embed()
# Stop here: the code below this point is not run.
raise SystemExit


# NOTE(review): everything from here to the end of the file is unreachable --
# the script terminates before this point. It appears to be an older
# evaluation path kept for reference, and it would not run as written:
#   * `ours_shapes` is not defined anywhere in this script (NameError).
#   * the "target" loop below also reads `ours_shapes`, so target_pc would be
#     sampled from the same meshes as ours_pc; presumably it should read the
#     ground-truth shapes instead -- TODO confirm before reviving.
#   * `emd_dis` is printed even when args.emd is false (NameError then).

# Convert each shape to a mesh with cubify (threshold 0.5) and sample points
# from it. Presumably ours_shapes[i] is a voxel occupancy grid -- TODO confirm.
ours_pc_list = []
for i in range(ours_shapes.shape[0]):
    m1 = cubify(torch.Tensor(ours_shapes[i]).unsqueeze(0),0.5)
    p1 = sample_points_from_meshes(m1)
    # ours_pc_list.append(np.expand_dims(p1,0))
    ours_pc_list.append(p1)
# Stack the per-shape samples (torch tensors) into one NumPy array.
ours_pc = np.vstack(ours_pc_list)

# NOTE(review): identical to the loop above (reads ours_shapes again); see the
# header note -- likely a copy-paste slip.
target_pc_list = []
for i in range(ours_shapes.shape[0]):
    m1 = cubify(torch.Tensor(ours_shapes[i]).unsqueeze(0),0.5)
    p1 = sample_points_from_meshes(m1)
    # target_pc_list.append(np.expand_dims(p1,0))
    target_pc_list.append(p1)
target_pc = np.vstack(target_pc_list)

# NOTE(review): unlike the main path earlier in the file, this takes the sqrt
# of the Chamfer value, so the two printed cd_dis numbers are not comparable.
cd_dis = torch.sqrt(chamfer_distance(torch.Tensor(ours_pc).cuda(), torch.Tensor(target_pc).cuda())[0])
print('cd_dis:', cd_dis)
if args.emd:
    emd_dis = []
    # Must equal the number of points per sampled cloud; presumably matches
    # sample_points_from_meshes' default sample count -- TODO confirm.
    # Each (dim, dim, 3) temporary below holds 3e8 elements (multiple GB).
    dim = 10000
    for i in range(ours_pc.shape[0]):
        print('emd',i)
        q1 = ours_pc[i]
        q2 = target_pc[i]
        # t1[i, j] = q1[i] and t2[i, j] = q2[j], so matrix[i, j] is the squared
        # distance between q1[i] and q2[j].
        t1 = np.repeat(q1,dim,axis=0).reshape(dim,dim,3)
        t2 = np.swapaxes(np.repeat(q2,dim,axis=0).reshape(dim,dim,3), 0, 1)
        diff = t1-t2
        matrix = diff[:,:,0]*diff[:,:,0]+diff[:,:,1]*diff[:,:,1]+diff[:,:,2]*diff[:,:,2]
        # Optimal one-to-one matching; col_ind[i] is the q2 point paired with q1[i].
        row_ind, col_ind = linear_sum_assignment(matrix)
        diff2=q1 - q2[col_ind]
        # diff2 = q1 - q2
        # Mean Euclidean distance of the matched pairs.
        emd_dis.append(np.mean(np.sqrt(diff2[:,0]*diff2[:,0]+diff2[:,1]*diff2[:,1]+diff2[:,2]*diff2[:,2])))
print('emd_dis:', np.mean(np.array(emd_dis)))

# Interactive IPython shell for inspecting the results.
embed()
